# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the Reshape layer."""
# pylint: disable=g-classes-have-attributes,g-direct-tensorflow-import

from keras.engine.base_layer import Layer
import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.Reshape')
class Reshape(Layer):
  """Layer that reshapes inputs into the given shape.

  Args:
    target_shape: Target shape. Tuple of integers, does not include the
      samples dimension (batch size). A `-1` entry marks a dimension whose
      size is inferred from the input.
    **kwargs: Any additional layer keyword arguments.

  Input shape:
    Arbitrary, although all dimensions in the input shape must be known/fixed.
    Use the keyword argument `input_shape` (tuple of integers, does not include
    the samples/batch size axis) when using this layer as the first layer
    in a model.

  Output shape:
    `(batch_size,) + target_shape`

  Example:

  >>> # as first layer in a Sequential model
  >>> model = tf.keras.Sequential()
  >>> model.add(tf.keras.layers.Reshape((3, 4), input_shape=(12,)))
  >>> # model.output_shape == (None, 3, 4), `None` is the batch size.
  >>> model.output_shape
  (None, 3, 4)

  >>> # as intermediate layer in a Sequential model
  >>> model.add(tf.keras.layers.Reshape((6, 2)))
  >>> model.output_shape
  (None, 6, 2)

  >>> # also supports shape inference using `-1` as dimension
  >>> model.add(tf.keras.layers.Reshape((-1, 2, 2)))
  >>> model.output_shape
  (None, 3, 2, 2)
  """

  def __init__(self, target_shape, **kwargs):
    """Creates a `tf.keras.layers.Reshape` layer instance.

    Args:
      target_shape: Target shape. Tuple of integers, does not include the
        samples dimension (batch size).
      **kwargs: Any additional layer keyword arguments.
    """
    super().__init__(**kwargs)
    # Stored as a tuple so it can be concatenated with the batch dim in `call`.
    self.target_shape = tuple(target_shape)

  def _fix_unknown_dimension(self, input_shape, output_shape):
    """Find and replace a missing dimension in an output shape.

    This is a near direct port of the internal Numpy function
    `_fix_unknown_dimension` in `numpy/core/src/multiarray/shape.c`

    Args:
      input_shape: Shape of array being reshaped
      output_shape: Desired shape of the array with at most a single -1 which
        indicates a dimension that should be derived from the input shape.
        Any negative entry is treated as the unknown dimension.

    Returns:
      A list holding the new output shape with the unknown dimension replaced
      by its computed value.

    Raises:
      ValueError: If the total array size of the output_shape is
        different than the input_shape, or more than one unknown dimension
        is specified.
    """
    output_shape = list(output_shape)

    # `known` is the product of the explicit dims; `unknown` is the index of
    # the single dim to infer (if any).
    known, unknown = 1, None
    for index, dim in enumerate(output_shape):
      if dim < 0:
        if unknown is None:
          unknown = index
        else:
          raise ValueError(
              'There must be at most one unknown dimension in output_shape. '
              f'Received: output_shape={output_shape}.')
      else:
        known *= dim

    original = np.prod(input_shape, dtype=int)
    if unknown is None:
      size_mismatch = original != known
    else:
      # The inferred dim must be a whole number; `known == 0` is checked first
      # so the modulo never divides by zero.
      size_mismatch = known == 0 or original % known != 0

    if size_mismatch:
      # Built only on the error path; `output_shape` still holds the requested
      # (unsubstituted) dims here.
      raise ValueError(
          'total size of new array must be unchanged, '
          'input_shape = {}, output_shape = {}'.format(input_shape,
                                                       output_shape))

    if unknown is not None:
      output_shape[unknown] = original // known
    return output_shape

  def compute_output_shape(self, input_shape):
    """Computes the static output shape for a given input shape.

    Args:
      input_shape: Input shape including the batch dimension; anything
        accepted by `tf.TensorShape`.

    Returns:
      A `tf.TensorShape` made of the input's batch dimension followed by the
      target shape. If every non-batch input dimension is known, a `-1` in
      the target shape is replaced by its inferred size; otherwise `-1`
      entries become `None`.

    Raises:
      ValueError: Propagated from `_fix_unknown_dimension` when the non-batch
        input dimensions are all known and incompatible with the target
        shape.
    """
    input_shape = tf.TensorShape(input_shape).as_list()
    output_shape = [input_shape[0]]
    sample_shape = input_shape[1:]
    if None in sample_shape:
      # input shape (partially) unknown? replace -1's with None's
      output_shape.extend(
          None if s == -1 else s for s in self.target_shape)
    else:
      output_shape.extend(
          self._fix_unknown_dimension(sample_shape, self.target_shape))
    return tf.TensorShape(output_shape)

  def call(self, inputs):
    """Reshapes `inputs` to `(batch,) + target_shape`.

    Args:
      inputs: Tensor to reshape. Its first dimension is read dynamically via
        `tf.shape(inputs)[0]` and kept as the leading output dimension.

    Returns:
      The reshaped tensor.
    """
    result = tf.reshape(inputs, (tf.shape(inputs)[0],) + self.target_shape)
    if not tf.executing_eagerly():
      # Set the static shape for the result since it might be lost during
      # array_ops reshape, e.g. some `None` dim in the result could be
      # inferred.
      result.set_shape(self.compute_output_shape(inputs.shape))
    return result

  def get_config(self):
    """Returns the base layer config extended with `target_shape`."""
    config = {'target_shape': self.target_shape}
    base_config = super().get_config()
    # Entries in `config` take precedence over same-named base entries.
    return {**base_config, **config}
